package com.atguigu.flink.chapter11.function;

import com.atguigu.flink.bean.WaterSensor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.functions.AggregateFunction;

/**
 * Demonstrates a user-defined aggregate function ({@link MyAvg}) in Flink's Table API / SQL.
 *
 * <p>The program builds a small bounded stream of {@code WaterSensor} readings, exposes it as
 * the temporary view {@code sensor}, registers {@link MyAvg} under the SQL name {@code my_avg},
 * and prints the per-{@code id} average of {@code vc}.
 *
 * @Author lizhenchao@atguigu.cn
 * @Date 2021/12/22 10:19
 */
public class Flink03_Aggregate {
    /**
     * Entry point: builds the demo pipeline and prints the result of the aggregate query.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        Configuration conf = new Configuration();
        // Pin the "rest.port" configuration key to a fixed port for this local run.
        conf.setInteger("rest.port", 20000);
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
        env.setParallelism(1);

        DataStreamSource<WaterSensor> waterSensorStream =
            env.fromElements(new WaterSensor("sensor_1", 1000L, 10),
                             new WaterSensor("sensor_1", 2000L, 20),
                             new WaterSensor("sensor_2", 3000L, 30),
                             new WaterSensor("sensor_1", 4000L, 40),
                             new WaterSensor("sensor_1", 5000L, 50),
                             new WaterSensor("sensor_2", 6000L, 60));
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
        Table table = tEnv.fromDataStream(waterSensorStream);
        tEnv.createTemporaryView("sensor", table);

        // 1. Usage from the Table API.
        // 1.1 Inline usage (no registration). These samples additionally need static imports of
        //     org.apache.flink.table.api.Expressions.$ and Expressions.call.
        /*table
            .groupBy($("id"))
            .select($("id"), call(MyAvg.class, $("vc")).as("avg_vc"))
            .execute()
            .print();*/

        /*table
            .groupBy($("id"))
            .aggregate(call(MyAvg.class, $("vc")).as("avg_vc"))
            .select($("id"), $("avg_vc"))
            .execute()
            .print();*/

        // 1.2 Register the function first, then use it.

        // 2. Usage from SQL: register the function under a name, then call it in the query.
        tEnv.createTemporaryFunction("my_avg", MyAvg.class);

        tEnv.sqlQuery("select" +
                          " id, " +
                          " my_avg(vc) avg_vc " +
                          "from sensor " +
                          "group by id")
            .execute()
            .print();

    }

    /**
     * User-defined aggregate function that averages {@code Integer} inputs, using {@link Avg}
     * as its accumulator.
     */
    public static class MyAvg extends AggregateFunction<Double, Avg> {
        /**
         * Returns the final aggregate result for the given accumulator.
         *
         * @param acc the accumulator holding the running sum and count
         * @return the average, or {@code null} if nothing has been accumulated
         */
        @Override
        public Double getValue(Avg acc) {
            return acc.avg();
        }

        /**
         * Creates a fresh, empty accumulator.
         *
         * @return a new {@link Avg} with sum and count both zero
         */
        @Override
        public Avg createAccumulator() {
            return new Avg();
        }

        /**
         * Folds one water-level value into the accumulator.
         *
         * <p>The first parameter must be the accumulator; the parameters after it (one or more)
         * are the values taking part in the aggregation.
         *
         * @param acc the accumulator to update
         * @param vc  the value to add; {@code null} values are ignored
         */
        public void accumulate(Avg acc, Integer vc) {
            // Skip NULL inputs instead of failing on auto-unboxing, mirroring SQL AVG semantics.
            if (vc == null) {
                return;
            }
            acc.sum += vc;
            acc.count++;
        }
    }

    /**
     * Accumulator for {@link MyAvg}: a running sum and the number of values added.
     *
     * <p>NOTE(review): {@code sum} is a 32-bit {@code Integer}, so very large totals can
     * overflow; it is left unchanged here to keep the public accumulator layout stable.
     */
    public static class Avg {
        public Integer sum = 0;
        public Long count = 0L;

        /**
         * Computes the average of the accumulated values.
         *
         * @return {@code sum / count} as a double, or {@code null} when no value has been
         *         accumulated (avoids the {@code NaN} produced by dividing zero by zero)
         */
        public Double avg() {
            if (count == 0L) {
                return null;
            }
            return sum * 1.0 / count;
        }
    }


}
